// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
#define EIGEN_TRIANGULARMATRIXVECTOR_H

namespace Eigen {

namespace internal {

template<typename Index,
		 int Mode,
		 typename LhsScalar,
		 bool ConjLhs,
		 typename RhsScalar,
		 bool ConjRhs,
		 int StorageOrder,
		 int Version = Specialized>
struct triangular_matrix_vector_product;

template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, ColMajor, Version>
{
	// Scalar type of lhs(i,j) * rhs(j).
	using ResScalar = typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType;

	// Compile-time decoding of the Mode bit-field.
	enum
	{
		IsLower = (Mode & Lower) == Lower,           // lower (otherwise upper) triangle
		HasUnitDiag = (Mode & UnitDiag) == UnitDiag, // stored diagonal ignored, rhs(i) added instead
		HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag  // stored diagonal ignored entirely
	};

	// Column-major kernel: accumulates alpha * tri(lhs) * rhs into res.
	static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
									  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr,
									  const RhsScalar& alpha);
};

/* Column-major triangular matrix * vector kernel.
 *
 * Accumulates alpha * T * rhs into res, where T is the lower or upper triangle
 * (selected by Mode) of the _rows x _cols column-major matrix stored at _lhs with
 * outer stride lhsStride. ConjLhs / ConjRhs conjugate the corresponding operand.
 * With UnitDiag the stored diagonal is never read and alpha * rhs(i) is added in
 * its place; with ZeroDiag the diagonal is skipped.
 *
 * The triangle is processed in panels of EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH columns.
 * The small triangle inside each panel is handled column by column. The rectangular
 * block below (Lower) or above (Upper) that triangle is delegated to
 * general_matrix_vector_product. For an Upper matrix with _cols > _rows, the trailing
 * columns [size, cols) are handled by one final general_matrix_vector_product call.
 *
 * Precondition: resIncr == 1 (see the assertion below).
 */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void
triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, ColMajor, Version>::run(
	Index _rows,
	Index _cols,
	const LhsScalar* _lhs,
	Index lhsStride,
	const RhsScalar* _rhs,
	Index rhsIncr,
	ResScalar* _res,
	Index resIncr,
	const RhsScalar& alpha)
{
	// res is mapped with unit stride below (ResMap), but resIncr is forwarded to the
	// general_matrix_vector_product calls; both only agree for a contiguous result.
	eigen_internal_assert(resIncr == 1);

	static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
	// The triangle lives in the leading min(_rows, _cols) square. For Lower, columns
	// beyond it lie above the diagonal and are dropped; for Upper, rows beyond it lie
	// below the diagonal and are dropped.
	Index size = (std::min)(_rows, _cols);
	Index rows = IsLower ? _rows : (std::min)(_rows, _cols);
	Index cols = IsLower ? (std::min)(_rows, _cols) : _cols;

	typedef Map<const Matrix<LhsScalar, Dynamic, Dynamic, ColMajor>, 0, OuterStride<>> LhsMap;
	const LhsMap lhs(_lhs, rows, cols, OuterStride<>(lhsStride));
	typename conj_expr_if<ConjLhs, LhsMap>::type cjLhs(lhs);

	typedef Map<const Matrix<RhsScalar, Dynamic, 1>, 0, InnerStride<>> RhsMap;
	const RhsMap rhs(_rhs, cols, InnerStride<>(rhsIncr));
	typename conj_expr_if<ConjRhs, RhsMap>::type cjRhs(rhs);

	typedef Map<Matrix<ResScalar, Dynamic, 1>> ResMap;
	ResMap res(_res, rows);

	typedef const_blas_data_mapper<LhsScalar, Index, ColMajor> LhsMapper;
	typedef const_blas_data_mapper<RhsScalar, Index, RowMajor> RhsMapper;

	for (Index pi = 0; pi < size; pi += PanelWidth) {
		Index actualPanelWidth = (std::min)(PanelWidth, size - pi);
		// Triangle inside the panel: one AXPY-like update per column.
		for (Index k = 0; k < actualPanelWidth; ++k) {
			Index i = pi + k;
			// Column i of the panel triangle covers rows [s, s + r). When the diagonal is
			// implicit, r is decremented to drop it (Lower additionally starts at i + 1).
			Index s = IsLower ? ((HasUnitDiag || HasZeroDiag) ? i + 1 : i) : pi;
			Index r = IsLower ? actualPanelWidth - k : k + 1;
			// --r is only evaluated when the diagonal is implicit (short-circuit).
			if ((!(HasUnitDiag || HasZeroDiag)) || (--r) > 0)
				res.segment(s, r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s, r);
			if (HasUnitDiag)
				res.coeffRef(i) += alpha * cjRhs.coeff(i);
		}
		// Rectangular block below (Lower) or above (Upper) the panel triangle.
		Index r = IsLower ? rows - pi - actualPanelWidth : pi;
		if (r > 0) {
			Index s = IsLower ? pi + actualPanelWidth : 0;
			general_matrix_vector_product<Index,
										  LhsScalar,
										  LhsMapper,
										  ColMajor,
										  ConjLhs,
										  RhsScalar,
										  RhsMapper,
										  ConjRhs,
										  BuiltIn>::run(r,
														actualPanelWidth,
														LhsMapper(&lhs.coeffRef(s, pi), lhsStride),
														RhsMapper(&rhs.coeffRef(pi), rhsIncr),
														&res.coeffRef(s),
														resIncr,
														alpha);
		}
	}
	// Upper with more columns than rows: the columns past the square are a plain gemv.
	if ((!IsLower) && cols > size) {
		general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjLhs, RhsScalar, RhsMapper, ConjRhs>::
			run(rows,
				cols - size,
				LhsMapper(&lhs.coeffRef(0, size), lhsStride),
				RhsMapper(&rhs.coeffRef(size), rhsIncr),
				_res,
				resIncr,
				alpha);
	}
}

template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, RowMajor, Version>
{
	// Scalar type of lhs(i,j) * rhs(j).
	using ResScalar = typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType;

	// Compile-time decoding of the Mode bit-field.
	enum
	{
		IsLower = (Mode & Lower) == Lower,           // lower (otherwise upper) triangle
		HasUnitDiag = (Mode & UnitDiag) == UnitDiag, // stored diagonal ignored, rhs(i) added instead
		HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag  // stored diagonal ignored entirely
	};

	// Row-major kernel: accumulates alpha * tri(lhs) * rhs into res.
	static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
									  const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr,
									  const ResScalar& alpha);
};

/* Row-major triangular matrix * vector kernel.
 *
 * Accumulates alpha * T * rhs into res (entries resIncr apart), where T is the
 * lower or upper triangle (selected by Mode) of the _rows x _cols row-major matrix
 * stored at _lhs with outer stride lhsStride. ConjLhs / ConjRhs conjugate the
 * corresponding operand. With UnitDiag the stored diagonal is never read and
 * alpha * rhs(i) is added in its place; with ZeroDiag the diagonal is skipped.
 *
 * Rows are processed in panels of EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH. Each row of
 * the small triangle inside a panel is a dot product. The rectangular block left of
 * (Lower) or right of (Upper) that triangle is delegated to
 * general_matrix_vector_product. For a Lower matrix with _rows > _cols, the trailing
 * rows [diagSize, rows) are handled by one final general_matrix_vector_product call.
 *
 * Precondition: rhsIncr == 1 (see the assertion below).
 */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void
triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, RowMajor, Version>::run(
	Index _rows,
	Index _cols,
	const LhsScalar* _lhs,
	Index lhsStride,
	const RhsScalar* _rhs,
	Index rhsIncr,
	ResScalar* _res,
	Index resIncr,
	const ResScalar& alpha)
{
	// rhs is mapped with unit stride below (RhsMap), but rhsIncr is forwarded to the
	// general_matrix_vector_product calls; both only agree for a contiguous rhs.
	eigen_internal_assert(rhsIncr == 1);

	static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
	// The triangle lives in the leading diagSize x diagSize square. For Lower, columns
	// beyond it lie above the diagonal and are dropped; for Upper, rows beyond it lie
	// below the diagonal and are dropped.
	Index diagSize = (std::min)(_rows, _cols);
	Index rows = IsLower ? _rows : diagSize;
	Index cols = IsLower ? diagSize : _cols;

	typedef Map<const Matrix<LhsScalar, Dynamic, Dynamic, RowMajor>, 0, OuterStride<>> LhsMap;
	const LhsMap lhs(_lhs, rows, cols, OuterStride<>(lhsStride));
	typename conj_expr_if<ConjLhs, LhsMap>::type cjLhs(lhs);

	typedef Map<const Matrix<RhsScalar, Dynamic, 1>> RhsMap;
	const RhsMap rhs(_rhs, cols);
	typename conj_expr_if<ConjRhs, RhsMap>::type cjRhs(rhs);

	typedef Map<Matrix<ResScalar, Dynamic, 1>, 0, InnerStride<>> ResMap;
	ResMap res(_res, rows, InnerStride<>(resIncr));

	typedef const_blas_data_mapper<LhsScalar, Index, RowMajor> LhsMapper;
	typedef const_blas_data_mapper<RhsScalar, Index, RowMajor> RhsMapper;

	for (Index pi = 0; pi < diagSize; pi += PanelWidth) {
		Index actualPanelWidth = (std::min)(PanelWidth, diagSize - pi);
		// Triangle inside the panel: one dot product per row.
		for (Index k = 0; k < actualPanelWidth; ++k) {
			Index i = pi + k;
			// Row i of the panel triangle covers columns [s, s + r). When the diagonal is
			// implicit, r is decremented to drop it (Upper additionally starts at i + 1).
			Index s = IsLower ? pi : ((HasUnitDiag || HasZeroDiag) ? i + 1 : i);
			Index r = IsLower ? k + 1 : actualPanelWidth - k;
			// --r is only evaluated when the diagonal is implicit (short-circuit).
			if ((!(HasUnitDiag || HasZeroDiag)) || (--r) > 0)
				res.coeffRef(i) +=
					alpha * (cjLhs.row(i).segment(s, r).cwiseProduct(cjRhs.segment(s, r).transpose())).sum();
			if (HasUnitDiag)
				res.coeffRef(i) += alpha * cjRhs.coeff(i);
		}
		// Rectangular block left of (Lower) or right of (Upper) the panel triangle.
		Index r = IsLower ? pi : cols - pi - actualPanelWidth;
		if (r > 0) {
			Index s = IsLower ? 0 : pi + actualPanelWidth;
			general_matrix_vector_product<Index,
										  LhsScalar,
										  LhsMapper,
										  RowMajor,
										  ConjLhs,
										  RhsScalar,
										  RhsMapper,
										  ConjRhs,
										  BuiltIn>::run(actualPanelWidth,
														r,
														LhsMapper(&lhs.coeffRef(pi, s), lhsStride),
														RhsMapper(&rhs.coeffRef(s), rhsIncr),
														&res.coeffRef(pi),
														resIncr,
														alpha);
		}
	}
	// Lower with more rows than columns: the rows past the square are a plain gemv.
	if (IsLower && rows > diagSize) {
		general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjLhs, RhsScalar, RhsMapper, ConjRhs>::
			run(rows - diagSize,
				cols,
				LhsMapper(&lhs.coeffRef(diagSize, 0), lhsStride),
				RhsMapper(&rhs.coeffRef(0), rhsIncr),
				&res.coeffRef(diagSize),
				resIncr,
				alpha);
	}
}

/***************************************************************************
 * Dispatch from triangular_product_impl to the trmv kernels above
 ***************************************************************************/

/* trmv_selector<Mode, StorageOrder>::run(lhs, rhs, dest, alpha) prepares the
 * operands of a triangular matrix * vector product and invokes the
 * triangular_matrix_vector_product kernel of the given storage order.
 * This file defines the ColMajor and RowMajor specializations. */
template<int Mode, int StorageOrder> struct trmv_selector;

} // end namespace internal

namespace internal {

template<int Mode, typename Lhs, typename Rhs>
struct triangular_product_impl<Mode, true, Lhs, false, Rhs, true>
{
	template<typename Dest>
	static void run(Dest& dst, const Lhs& lhs, const Rhs& rhs, const typename Dest::Scalar& alpha)
	{
		eigen_assert(dst.rows() == lhs.rows() && dst.cols() == rhs.cols());

		internal::trmv_selector<Mode, (int(internal::traits<Lhs>::Flags) & RowMajorBit) ? RowMajor : ColMajor>::run(
			lhs, rhs, dst, alpha);
	}
};

template<int Mode, typename Lhs, typename Rhs>
struct triangular_product_impl<Mode, false, Lhs, true, Rhs, false>
{
	template<typename Dest>
	static void run(Dest& dst, const Lhs& lhs, const Rhs& rhs, const typename Dest::Scalar& alpha)
	{
		eigen_assert(dst.rows() == lhs.rows() && dst.cols() == rhs.cols());

		Transpose<Dest> dstT(dst);
		internal::trmv_selector<(Mode & (UnitDiag | ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
								(int(internal::traits<Rhs>::Flags) & RowMajorBit) ? ColMajor
																				  : RowMajor>::run(rhs.transpose(),
																								   lhs.transpose(),
																								   dstT,
																								   alpha);
	}
};

} // end namespace internal

namespace internal {

// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.

/* Column-major trmv driver: dest += alpha * lhs * rhs, with lhs triangular (Mode).
 *
 * Scalar factors and conjugations nested in lhs / rhs are extracted through
 * blas_traits and folded into the kernel's alpha and Conj flags.
 *
 * The column-major kernel is called with a contiguous result (resIncr == 1). dest is
 * therefore written in place only when its inner stride is 1 and alpha can be
 * passed as an RhsScalar, which is not the case for a complex lhs times a real rhs
 * with a non-real alpha. Otherwise the kernel runs on a temporary that is then
 * copied or added back into dest.
 */
template<int Mode>
struct trmv_selector<Mode, ColMajor>
{
	template<typename Lhs, typename Rhs, typename Dest>
	static void run(const Lhs& lhs, const Rhs& rhs, Dest& dest, const typename Dest::Scalar& alpha)
	{
		typedef typename Lhs::Scalar LhsScalar;
		typedef typename Rhs::Scalar RhsScalar;
		typedef typename Dest::Scalar ResScalar;
		typedef typename Dest::RealScalar RealScalar;

		typedef internal::blas_traits<Lhs> LhsBlasTraits;
		typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
		typedef internal::blas_traits<Rhs> RhsBlasTraits;
		typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;

		typedef Map<Matrix<ResScalar, Dynamic, 1>,
					EIGEN_PLAIN_ENUM_MIN(AlignedMax, internal::packet_traits<ResScalar>::size)>
			MappedDest;

		typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
		typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);

		// Scalar factors nested in the operands (e.g. lhs = s * A) are applied through alpha.
		LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
		RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
		ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;

		enum
		{
			// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1;
			// on the other hand, it is good for the cache to pack the vector anyway...
			EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime == 1,
			ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
			MightCannotUseDest = (Dest::InnerStrideAtCompileTime != 1) || ComplexByReal
		};

		gemv_static_vector_if<ResScalar, Dest::SizeAtCompileTime, Dest::MaxSizeAtCompileTime, MightCannotUseDest>
			static_dest;

		// A complex alpha cannot be handed to the kernel as a real RhsScalar.
		bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha) == RealScalar(0));
		bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;

		RhsScalar compatibleAlpha = get_factor<ResScalar, RhsScalar>::run(actualAlpha);

		ei_declare_aligned_stack_constructed_variable(
			ResScalar, actualDestPtr, dest.size(), evalToDest ? dest.data() : static_dest.data());

		if (!evalToDest) {
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
			Index size = dest.size();
			EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif
			if (!alphaIsCompatible) {
				// Compute lhs * rhs unscaled into a zeroed buffer; actualAlpha is applied below.
				MappedDest(actualDestPtr, dest.size()).setZero();
				compatibleAlpha = RhsScalar(1);
			} else
				MappedDest(actualDestPtr, dest.size()) = dest;
		}

		internal::triangular_matrix_vector_product<Index,
												   Mode,
												   LhsScalar,
												   LhsBlasTraits::NeedToConjugate,
												   RhsScalar,
												   RhsBlasTraits::NeedToConjugate,
												   ColMajor>::run(actualLhs.rows(),
																  actualLhs.cols(),
																  actualLhs.data(),
																  actualLhs.outerStride(),
																  actualRhs.data(),
																  actualRhs.innerStride(),
																  actualDestPtr,
																  1,
																  compatibleAlpha);

		if (!evalToDest) {
			if (!alphaIsCompatible)
				dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
			else
				dest = MappedDest(actualDestPtr, dest.size());
		}

		if (((Mode & UnitDiag) == UnitDiag) && (lhs_alpha != LhsScalar(1))) {
			// The kernel scaled the implicit unit diagonal by the whole
			// actualAlpha = alpha * lhs_alpha * rhs_alpha. That adds alpha * lhs_alpha * rhs(i)
			// per diagonal entry, whereas a unit diagonal must contribute alpha * rhs(i).
			// Remove the excess alpha * (lhs_alpha - 1) * rhs(i). The alpha factor matters,
			// e.g., for dst.noalias() -= (s * A).triangularView<UnitLower>() * v (alpha == -1).
			Index diagSize = (std::min)(lhs.rows(), lhs.cols());
			dest.head(diagSize) -= alpha * ((lhs_alpha - LhsScalar(1)) * rhs.head(diagSize));
		}
	}
};

template<int Mode>
struct trmv_selector<Mode, RowMajor>
{
	template<typename Lhs, typename Rhs, typename Dest>
	static void run(const Lhs& lhs, const Rhs& rhs, Dest& dest, const typename Dest::Scalar& alpha)
	{
		typedef typename Lhs::Scalar LhsScalar;
		typedef typename Rhs::Scalar RhsScalar;
		typedef typename Dest::Scalar ResScalar;

		typedef internal::blas_traits<Lhs> LhsBlasTraits;
		typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
		typedef internal::blas_traits<Rhs> RhsBlasTraits;
		typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
		typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;

		typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
		typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);

		LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
		RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
		ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;

		enum
		{
			DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime == 1
		};

		gemv_static_vector_if<RhsScalar,
							  ActualRhsTypeCleaned::SizeAtCompileTime,
							  ActualRhsTypeCleaned::MaxSizeAtCompileTime,
							  !DirectlyUseRhs>
			static_rhs;

		ei_declare_aligned_stack_constructed_variable(RhsScalar,
													  actualRhsPtr,
													  actualRhs.size(),
													  DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data())
																	 : static_rhs.data());

		if (!DirectlyUseRhs) {
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
			Index size = actualRhs.size();
			EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif
			Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
		}

		internal::triangular_matrix_vector_product<Index,
												   Mode,
												   LhsScalar,
												   LhsBlasTraits::NeedToConjugate,
												   RhsScalar,
												   RhsBlasTraits::NeedToConjugate,
												   RowMajor>::run(actualLhs.rows(),
																  actualLhs.cols(),
																  actualLhs.data(),
																  actualLhs.outerStride(),
																  actualRhsPtr,
																  1,
																  dest.data(),
																  dest.innerStride(),
																  actualAlpha);

		if (((Mode & UnitDiag) == UnitDiag) && (lhs_alpha != LhsScalar(1))) {
			Index diagSize = (std::min)(lhs.rows(), lhs.cols());
			dest.head(diagSize) -= (lhs_alpha - LhsScalar(1)) * rhs.head(diagSize);
		}
	}
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
